/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

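/*!
 * \file matmul_impl.h
 * \brief Per-configuration matmul parameter structs (MatmulParamsNorm, MatmulParamsMDL, ...) and the
 *        MatmulParams selector, whose specializations expose the chosen struct as PARAMS.
 */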
#ifndef IMPL_MATMUL_MODULES_PARAMS_H
#define IMPL_MATMUL_MODULES_PARAMS_H

#include "kernel_macros.h"

#include "lib/matmul/tiling.h"
#include "kernel_operator.h"
#include "../matmul_utils.h"
#include "../matmul_constant_tiling_impl.h"
#include "matmul_type_def.h"
#include "resource/cube_in_buffer/global_cache.h"

namespace matmul {
/* **************************************************************************************************
 * MatmulParamsBase
 * ************************************************************************************************* */
/**
 * \brief Empty common base of the matmul parameter structs.
 *
 * Carries no state. MatmulParamsNorm, MatmulParamsMDL and MatmulParamsIBShareNorm derive from it
 * with the same template arguments.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParamsBase {
    __aicore__ inline MatmulParamsBase() {}
};

/**
 * \brief Member state of the CFG_NORM matmul implementation.
 *
 * MatmulParamsNormQuant, MatmulParamsNormOuterProduct, MatmulParamsNormQuantOuterProduct and
 * MatmulParamsBasicBlock derive from this struct, directly or indirectly.
 * Most members have no in-class initializer. NOTE(review): presumably the matmul init/setup code
 * assigns them before use; confirm.
 *
 * \tparam A_TYPE    descriptor of matrix A; A_TYPE::T is its element type
 * \tparam B_TYPE    descriptor of matrix B; B_TYPE::T is its element type
 * \tparam C_TYPE    descriptor of output C; C_TYPE::T is its element type
 * \tparam BIAS_TYPE descriptor of the bias; BIAS_TYPE::T is its element type
 * \tparam MM_CFG    matmul configuration; also parameterizes the tiling member
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParamsNorm : public MatmulParamsBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG> {
    // Type derived from A's element type via GetDstType (presumably the L0C accumulation type).
    using L0cT = typename GetDstType<typename A_TYPE::T>::Type;
    __aicore__ inline MatmulParamsNorm() {};
    // Element types of A, B, C and bias.
    using SrcT = typename A_TYPE::T;
    using SrcBT = typename B_TYPE::T;
    using DstT = typename C_TYPE::T;
    using BiasT = typename BIAS_TYPE::T;
    // Queue at TPosition::C1 that owns the bias buffer (cacheHeadBias_ is allocated/released through it).
    TQue<TPosition::C1, QUEUE_DEPTH> qidBias_;
#if __CCE_AICORE__ < 200
    // A2/B2 queues; these members only exist when __CCE_AICORE__ < 200.
    TQue<TPosition::A2, QUEUE_DEPTH> qidA2_;
    TQue<TPosition::B2, QUEUE_DEPTH> qidB2_;
#endif

    LocalTensor<BiasT> cacheHeadBias_; // Allocate and release using qidBias_

    // NOTE(review): presumably hold A / B when an operand is supplied as a scalar; confirm.
    SrcT aScalar_;
    SrcBT bScalar_;
    DEBUG_CODE(int calCount_ = 0);

    // Buffer addresses of the A, B and bias inputs (presumably on-chip buffers; confirm).
    TBuffAddr leftMatrix_;
    TBuffAddr rightMatrix_;
    TBuffAddr inputBias_;

    // __gm__ (global-memory) pointers to the A, B and bias data.
    __gm__ SrcT* aGlobal_;
    __gm__ SrcBT* bGlobal_;
    __gm__ BiasT* biasGlobal_;

    // Pipe pointer; this struct declares no destructor and never releases it.
    TPipe* tpipe_;
    // Tiling parameters, specialized on MM_CFG.
    MatmulTiling<MM_CFG> tiling_;
    // __gm__ workspace base address.
    __gm__ uint8_t* cacheWorkspaceAddr;

#if __CCE_AICORE__ < 220
    // __ubuf__ workspace state; these members only exist when __CCE_AICORE__ < 220.
    __ubuf__ uint8_t* cacheUBWorkspaceAddr = nullptr;
    LocalTensor<uint8_t> localWorkspace;
    // NOTE: "nd2nz0ffset" is spelled with a digit zero; renaming it would break any code that uses it.
    // Presumably per-stage offsets into the workspace (ND->NZ, transpose, CO2); confirm.
    int nd2nz0ffset = 0;
    int transOffset = 0;
    int co2Offset = 0;
#endif

    // single-core problem sizes along M, N and K
    int singleCoreM_;
    int singleCoreN_;
    int singleCoreK_;
    // iterate nums in mnk axis
    int mIter_;
    int nIter_;
    int kIter_;

    // baseUseX_ is the same as baseX in most cases, while it will be smaller than baseX when dealing with tail cases
    // measured in element
    int baseUseM_;
    int baseUseK_;
    int baseUseN_;
    // measured in cube block
    int blockUseM_;
    int blockUseK_;
    int blockUseN_;

    bool isFirstIter_;
    bool isTransposeA_; // whether A matrix need to transpose
    bool isTransposeB_; // whether B matrix need to transpose
    // whether bias is enabled. NOTE(review): an earlier comment said this defaults to false, but there is
    // no in-class initializer; presumably it is set during initialization.
    bool enableBias_;

    // tail sizes along M, K and N (presumably the extent of the last, partial base block)
    int tailM_, tailK_, tailN_;
    // current c matrix coordinate
    int curM_, curN_;
    // current c matrix step size, there could be tail steps
    int curStepM_, curStepN_;
    // current c matrix step block coordinate
    int stepMIdx_, stepNIdx_;

    // HF32 mode switch and its transform mode (presumably forwarded to the HF32 hardware setting)
    bool enHF32Mode_;
    int32_t hf32TransMode_;
    uint8_t subBlockIdx_;

    // NOTE(review): presumably baseM * baseN; confirm against the init code.
    int baseMN_;

    // 64-bit addresses; NOTE(review): their producers/consumers are not visible in this file.
    uint64_t dataPtr_;
    uint64_t tilingPtr_;
};

/**
 * \brief CFG_NORM parameter state extended with quantization settings.
 *
 * Selected by MatmulParams for the quantized A/C element-type combinations when the schedule
 * is not OUTER_PRODUCT.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParamsNormQuant : public MatmulParamsNorm<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG> {
    __aicore__ inline MatmulParamsNormQuant() {}

    // Scalar quantization parameter (presumably used by the scalar quant modes).
    uint64_t quantScalar_{0};
    // Quantization parameters in global memory (presumably used by the vector "v*" modes).
    GlobalTensor<uint64_t> quantTensor_;
    // Quantization mode:
    //   0: no quant, 1: deqf16, 2: vdeqf16, 3: QF322B8_PRE, 4: VQF322B8_PRE,
    //   5: REQ8 (s32->u8/s8), 6: VREQ8 (s32->u8/s8)
    uint8_t quantMode_{0};
};

/**
 * \brief CFG_NORM parameter state for ScheduleType::OUTER_PRODUCT (non-quantized dtypes).
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParamsNormOuterProduct : public MatmulParamsNorm<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG> {
    __aicore__ inline MatmulParamsNormOuterProduct() {}

    // NOTE(review): presumably the M/N step counters used by the outer-product schedule; confirm.
    int sMadMStep_{0};
    int sMadNStep_{0};
};

/**
 * \brief Quantized CFG_NORM parameter state for ScheduleType::OUTER_PRODUCT.
 *
 * Derives from MatmulParamsNormQuant, so it cannot also inherit MatmulParamsNormOuterProduct;
 * the two outer-product members are therefore declared again here.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParamsNormQuantOuterProduct : public MatmulParamsNormQuant<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG> {
    __aicore__ inline MatmulParamsNormQuantOuterProduct() {}

    // NOTE(review): presumably the M/N step counters used by the outer-product schedule; confirm.
    int sMadMStep_{0};
    int sMadNStep_{0};
};


/**
 * \brief Member state of the CFG_MDL matmul implementation.
 *
 * Selected by MatmulParams for GetMatmulVersion(CFG_MDL). Unlike MatmulParamsNorm, it keeps
 * step-level bookkeeping (mStepIter_, stepKaIdx_, ...) and quant/anti-quant fields in one struct.
 * Most members have no in-class initializer. NOTE(review): presumably the matmul init/setup code
 * assigns them before use; confirm.
 *
 * \tparam A_TYPE    descriptor of matrix A; A_TYPE::T is its element type
 * \tparam B_TYPE    descriptor of matrix B; B_TYPE::T is its element type
 * \tparam C_TYPE    descriptor of output C; C_TYPE::T is its element type
 * \tparam BIAS_TYPE descriptor of the bias; BIAS_TYPE::T is its element type
 * \tparam MM_CFG    matmul configuration; also parameterizes the tiling member
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParamsMDL : public MatmulParamsBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG> {
    // Type derived from A's element type via GetDstType (presumably the L0C accumulation type).
    using L0cT = typename GetDstType<typename A_TYPE::T>::Type;
    __aicore__ inline MatmulParamsMDL() {};
    // Element types of A, B, C and bias.
    using SrcT = typename A_TYPE::T;
    using SrcBT = typename B_TYPE::T;
    using DstT = typename C_TYPE::T;
    using BiasT = typename BIAS_TYPE::T;

    // Queue at TPosition::C1 used for the bias buffer.
    TQue<TPosition::C1, QUEUE_DEPTH> qidBias_;
#if __CCE_AICORE__ < 200
    // A2/B2 queues; these members only exist when __CCE_AICORE__ < 200.
    TQue<TPosition::A2, QUEUE_DEPTH> qidA2_;
    TQue<TPosition::B2, QUEUE_DEPTH> qidB2_;
#endif

    DEBUG_CODE(int calCount_ = 0);

    // Buffer addresses of the A, B and bias inputs (presumably on-chip buffers; confirm).
    TBuffAddr leftMatrix_;
    TBuffAddr rightMatrix_;
    TBuffAddr inputBias_;

    // __gm__ (global-memory) pointers to the A, B and bias data.
    __gm__ SrcT* aGlobal_;
    __gm__ SrcBT* bGlobal_;
    __gm__ BiasT* biasGlobal_;

    // Pipe pointer; this struct declares no destructor and never releases it.
    TPipe* tpipe_;
    // Tiling parameters, specialized on MM_CFG.
    MatmulTiling<MM_CFG> tiling_;
    // __gm__ workspace base address.
    __gm__ uint8_t* cacheWorkspaceAddr;

#if __CCE_AICORE__ < 220
    // __ubuf__ workspace state; these members only exist when __CCE_AICORE__ < 220.
    __ubuf__ uint8_t* cacheUBWorkspaceAddr = nullptr;
    LocalTensor<uint8_t> localWorkspace;
    // NOTE: "nd2nz0ffset" is spelled with a digit zero; renaming it would break any code that uses it.
    int nd2nz0ffset = 0;
    int transOffset = 0;
    int co2Offset = 0;
#endif

    // single-core problem sizes along M, N and K
    int singleCoreM_;
    int singleCoreN_;
    int singleCoreK_;
    // iterate nums in mnk axis
    int mIter_;
    int nIter_;
    int kIter_;
    // iterate nums in mn step axis (and in the Ka / Kb / K step axes)
    // NOTE(review): minStepK_ and ka/kbStepFactor_ presumably relate the A-side and B-side K steps; confirm.
    int mStepIter_;
    int nStepIter_;
    int kaStepIter_;
    int kbStepIter_;
    int kStepIter_;
    int minStepK_;
    int kaStepFactor_;
    int kbStepFactor_;

    // baseUseX_ is the same as baseX in most cases, while it will be smaller than baseX when dealing with tail cases
    // in unit of element
    int baseUseM_;
    int baseUseK_;
    int baseUseN_;
    // in unit of cube block
    int blockUseM_;
    int blockUseK_;
    int blockUseN_;

    // step-level counterparts of the sizes above
    // in unit of element
    int baseUseStepM_;
    int baseUseStepN_;
    int baseUseStepKa_;
    int baseUseStepKb_;
    // in unit of cube block
    int blockUseStepM_;
    int blockUseStepN_;
    int blockUseStepKa_;
    int blockUseStepKb_;

    bool isFirstIter_;
    bool isTransposeA_; // whether A matrix need to transpose
    bool isTransposeB_; // whether B matrix need to transpose
    // whether bias is enabled. NOTE(review): an earlier comment said this defaults to false, but there is
    // no in-class initializer; presumably it is set during initialization.
    bool enableBias_;

    // in unit of element
    int tailM_, tailK_, tailN_;
    // in unit of element
    int tailStepM_, tailStepN_, tailStepKa_, tailStepKb_;
    // current c matrix coordinate, in unit of baseMN
    int curM_, curN_;
    // current c matrix step size, in unit of baseMNK , there could be tail steps
    int curStepM_, curStepN_;
    // current c matrix step block coordinate, in unit of stepMNK
    int stepMIdx_, stepNIdx_, stepKaIdx_, stepKbIdx_;

    // stepKa == kIter (resp. stepKb == kIter): presumably true when one K step of A (resp. B) covers the
    // whole K extent, i.e. K is fully loaded into A1 (resp. B1); confirm.
    bool isA1KFullLoad_, isB1KFullLoad_;

    // HF32 mode switch and its transform mode (presumably forwarded to the HF32 hardware setting)
    bool enHF32Mode_;
    int32_t hf32TransMode_;
    uint8_t subBlockIdx_;

    // NOTE(review): presumably baseM * baseN; confirm against the init code.
    int baseMN_;
    // NOTE(review): presumably cache depth factors for the A1/B1 buffers; confirm.
    int cacheA1Factor_, cacheB1Factor_;
    // Scalar quantization parameter.
    uint64_t quantScalar_ = 0;
#if __CCE_AICORE__ >= 220
    // Outer-product step counters; only declared when __CCE_AICORE__ >= 220.
    int sMadMStep_ = 0;
    int sMadNStep_ = 0;
#endif
    // 64-bit addresses; NOTE(review): their producers/consumers are not visible in this file.
    uint64_t dataPtr_;
    uint64_t tilingPtr_;
    // Quantization parameters in global memory.
    GlobalTensor<uint64_t> quantTensor_;
    // 0: no quant, 1: deqf16, 2: vdeqf16;
    // NOTE(review): MatmulParamsNormQuant documents further modes (3-6); confirm whether MDL uses them.
    uint8_t quantMode_ = 0;
    // anti quant param: offset and scale, each as a scalar and as a local tensor
    // (presumably only one form is used per call; confirm).
    SrcT antiQuantOffsetScalar_;
    SrcT antiQuantScaleScalar_;
    LocalTensor<SrcT> antiQuantOffsetTensor_;
    LocalTensor<SrcT> antiQuantScaleTensor_;
};

/**
 * \brief Parameter state for the MM_CFG_BB (basic block) version.
 *
 * Adds nothing to MatmulParamsNorm; it is a distinct type selected by MatmulParams for
 * GetMatmulVersion(MM_CFG_BB).
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParamsBasicBlock : public MatmulParamsNorm<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG> {
    __aicore__ inline MatmulParamsBasicBlock() {}
};

/**
 * \brief Member state of the CFG_IBSHARE_NORM matmul implementation.
 *
 * Selected by MatmulParams for GetMatmulVersion(CFG_IBSHARE_NORM). Its members mirror
 * MatmulParamsNorm, except that it has no `__CCE_AICORE__ < 220` workspace members.
 * Most members have no in-class initializer. NOTE(review): presumably the matmul init/setup code
 * assigns them before use; confirm.
 *
 * \tparam A_TYPE    descriptor of matrix A; A_TYPE::T is its element type
 * \tparam B_TYPE    descriptor of matrix B; B_TYPE::T is its element type
 * \tparam C_TYPE    descriptor of output C; C_TYPE::T is its element type
 * \tparam BIAS_TYPE descriptor of the bias; BIAS_TYPE::T is its element type
 * \tparam MM_CFG    matmul configuration; also parameterizes the tiling member
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParamsIBShareNorm : public MatmulParamsBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG> {
    // Type derived from A's element type via GetDstType (presumably the L0C accumulation type).
    using L0cT = typename GetDstType<typename A_TYPE::T>::Type;
    __aicore__ inline MatmulParamsIBShareNorm() {};
    // Element types of A, B, C and bias.
    using SrcT = typename A_TYPE::T;
    using SrcBT = typename B_TYPE::T;
    using DstT = typename C_TYPE::T;
    using BiasT = typename BIAS_TYPE::T;
    // Queue at TPosition::C1 that owns the bias buffer (cacheHeadBias_ is allocated/released through it).
    TQue<TPosition::C1, QUEUE_DEPTH> qidBias_;

#if __CCE_AICORE__ < 200
    // A2/B2 queues; these members only exist when __CCE_AICORE__ < 200.
    TQue<TPosition::A2, QUEUE_DEPTH> qidA2_;
    TQue<TPosition::B2, QUEUE_DEPTH> qidB2_;
#endif

    LocalTensor<BiasT> cacheHeadBias_; // Allocate and release using qidBias_

    // NOTE(review): presumably hold A / B when an operand is supplied as a scalar; confirm.
    SrcT aScalar_;
    SrcBT bScalar_;
    DEBUG_CODE(int calCount_ = 0);

    // Buffer addresses of the A, B and bias inputs (presumably on-chip buffers; confirm).
    TBuffAddr leftMatrix_;
    TBuffAddr rightMatrix_;
    TBuffAddr inputBias_;

    // __gm__ (global-memory) pointers to the A, B and bias data.
    __gm__ SrcT* aGlobal_;
    __gm__ SrcBT* bGlobal_;
    __gm__ BiasT* biasGlobal_;

    // Pipe pointer; this struct declares no destructor and never releases it.
    TPipe* tpipe_;
    // Tiling parameters, specialized on MM_CFG.
    MatmulTiling<MM_CFG> tiling_;
    // __gm__ workspace base address.
    __gm__ uint8_t* cacheWorkspaceAddr;

    // single-core problem sizes along M, N and K
    int singleCoreM_;
    int singleCoreN_;
    int singleCoreK_;
    // iterate nums in mnk axis
    int mIter_;
    int nIter_;
    int kIter_;

    // baseUseX_ is the same as baseX in most cases, while it will be smaller than baseX when dealing with tail cases
    // measured in element
    int baseUseM_;
    int baseUseK_;
    int baseUseN_;
    // measured in cube block
    int blockUseM_;
    int blockUseK_;
    int blockUseN_;

    bool isFirstIter_;
    bool isTransposeA_; // whether A matrix need to transpose
    bool isTransposeB_; // whether B matrix need to transpose
    // whether bias is enabled. NOTE(review): an earlier comment said this defaults to false, but there is
    // no in-class initializer; presumably it is set during initialization.
    bool enableBias_;

    // tail sizes along M, K and N (presumably the extent of the last, partial base block)
    int tailM_, tailK_, tailN_;
    // current c matrix coordinate
    int curM_, curN_;
    // current c matrix step size, there could be tail steps
    int curStepM_, curStepN_;
    // current c matrix step block coordinate
    int stepMIdx_, stepNIdx_;

    // HF32 mode switch and its transform mode (presumably forwarded to the HF32 hardware setting)
    bool enHF32Mode_;
    int32_t hf32TransMode_;
    uint8_t subBlockIdx_;

    // NOTE(review): presumably baseM * baseN; confirm against the init code.
    int baseMN_;

    // 64-bit addresses; NOTE(review): their producers/consumers are not visible in this file.
    uint64_t dataPtr_;
    uint64_t tilingPtr_;
};

/* **************************************************************************************************
 * MatmulParams
 * ************************************************************************************************* */
/**
 * \brief Maps a matmul version (MM_VER) to its parameter struct.
 *
 * The primary template defines no PARAMS alias. The partial specializations below provide
 * `using PARAMS = ...` for each supported version; the CFG_NORM ones additionally use the
 * ENABLE_QUANT slot for std::enable_if-based selection on dtype and schedule type.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG, MatmulVersion MM_VER,
    class ENABLE_QUANT = void>
struct MatmulParams {
    __aicore__ inline MatmulParams() {}
};

// CFG_NORM
#if __CCE_AICORE__ >= 220
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(CFG_NORM),
    typename std::enable_if<!(((
        IsSameType<typename A_TYPE::T, int8_t>::value ||
        IsSameType<typename A_TYPE::T, int4b_t>::value) &&
        IsSameType<typename C_TYPE::T, half>::value) ||
        (IsSameType<typename A_TYPE::T, int8_t>::value &&
        (IsSameType<typename C_TYPE::T, int8_t>::value ||
        IsSameType<typename C_TYPE::T, uint8_t>::value))
#if __CCE_AICORE__ == 220
        || ((IsSameType<typename A_TYPE::T, half>::value ||
        IsSameType<typename A_TYPE::T, bfloat16_t>::value) &&
        IsSameType<typename C_TYPE::T, int8_t>::value)
#endif
        ) && ToMatmulConfig(MM_CFG).scheduleType != ScheduleType::OUTER_PRODUCT>::type> {
    __aicore__ inline MatmulParams(){};
    using PARAMS = MatmulParamsNorm<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>;
};

template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(CFG_NORM),
    typename std::enable_if<!(((
        IsSameType<typename A_TYPE::T, int8_t>::value ||
        IsSameType<typename A_TYPE::T, int4b_t>::value) &&
        IsSameType<typename C_TYPE::T, half>::value) ||
        (IsSameType<typename A_TYPE::T, int8_t>::value &&
        (IsSameType<typename C_TYPE::T, int8_t>::value ||
        IsSameType<typename C_TYPE::T, uint8_t>::value))
#if __CCE_AICORE__ == 220
        || ((IsSameType<typename A_TYPE::T, half>::value ||
        IsSameType<typename A_TYPE::T, bfloat16_t>::value) &&
        IsSameType<typename C_TYPE::T, int8_t>::value)
#endif
        ) && ToMatmulConfig(MM_CFG).scheduleType == ScheduleType::OUTER_PRODUCT>::type> {
    __aicore__ inline MatmulParams(){};
    using PARAMS = MatmulParamsNormOuterProduct<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>;
};
#else
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(CFG_NORM),
    typename std::enable_if<!(
        (IsSameType<typename A_TYPE::T, int8_t>::value && IsSameType<typename C_TYPE::T, half>::value) ||
        (IsSameType<typename A_TYPE::T, int8_t>::value && IsSameType<typename C_TYPE::T, int8_t>::value))>::type> {
    __aicore__ inline MatmulParams(){};
    using PARAMS = MatmulParamsNorm<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>;
};
#endif

#if __CCE_AICORE__ >= 220
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(CFG_NORM),
    typename std::enable_if<(((
        IsSameType<typename A_TYPE::T, int8_t>::value ||
        IsSameType<typename A_TYPE::T, int4b_t>::value) &&
        IsSameType<typename C_TYPE::T, half>::value) ||
        (IsSameType<typename A_TYPE::T, int8_t>::value &&
        (IsSameType<typename C_TYPE::T, int8_t>::value ||
        IsSameType<typename C_TYPE::T, uint8_t>::value))
#if __CCE_AICORE__ == 220
        || ((IsSameType<typename A_TYPE::T, half>::value ||
        IsSameType<typename A_TYPE::T, bfloat16_t>::value) &&
        IsSameType<typename C_TYPE::T, int8_t>::value)
#endif        
        ) && ToMatmulConfig(MM_CFG).scheduleType != ScheduleType::OUTER_PRODUCT>::type> {
    __aicore__ inline MatmulParams(){};
    using PARAMS = MatmulParamsNormQuant<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>;
};

template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(CFG_NORM),
    typename std::enable_if<(((
        IsSameType<typename A_TYPE::T, int8_t>::value ||
        IsSameType<typename A_TYPE::T, int4b_t>::value) &&
        IsSameType<typename C_TYPE::T, half>::value) ||
        (IsSameType<typename A_TYPE::T, int8_t>::value &&
        (IsSameType<typename C_TYPE::T, int8_t>::value ||
        IsSameType<typename C_TYPE::T, uint8_t>::value))
#if __CCE_AICORE__ == 220
        || ((IsSameType<typename A_TYPE::T, half>::value ||
        IsSameType<typename A_TYPE::T, bfloat16_t>::value) &&
        IsSameType<typename C_TYPE::T, int8_t>::value)
#endif        
        ) && ToMatmulConfig(MM_CFG).scheduleType == ScheduleType::OUTER_PRODUCT>::type> {
    __aicore__ inline MatmulParams(){};
    using PARAMS = MatmulParamsNormQuantOuterProduct<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>;
};
#else
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(CFG_NORM),
    typename std::enable_if<(
        (IsSameType<typename A_TYPE::T, int8_t>::value && IsSameType<typename C_TYPE::T, half>::value) ||
        (IsSameType<typename A_TYPE::T, int8_t>::value && IsSameType<typename C_TYPE::T, int8_t>::value))>::type> {
    __aicore__ inline MatmulParams(){};
    using PARAMS = MatmulParamsNormQuant<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>;
};
#endif

// CFG_MDL: always maps to MatmulParamsMDL (no dtype- or schedule-based variants in this file).
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(CFG_MDL)> {
    using PARAMS = MatmulParamsMDL<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>;
    __aicore__ inline MatmulParams() {}
};

// MM_CFG_BB: always maps to MatmulParamsBasicBlock.
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(MM_CFG_BB)> {
    using PARAMS = MatmulParamsBasicBlock<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>;
    __aicore__ inline MatmulParams() {}
};

// CFG_IBSHARE_NORM: always maps to MatmulParamsIBShareNorm.
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
struct MatmulParams<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, GetMatmulVersion(CFG_IBSHARE_NORM)> {
    using PARAMS = MatmulParamsIBShareNorm<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>;
    __aicore__ inline MatmulParams() {}
};

}

#endif